import cdtools
from scipy.io import loadmat, savemat
import numpy as np
import torch as t
from helper_functions import (center_probe, center_probe_fourier,
                              split_dataset, fourier_pad)

#
# Some initial options and data loading
#

# Set view to True to plot the data and intermediate results while the
# reconstructions run; matplotlib is only imported in that case
view = False # Whether or not to plot the results as the calculation proceeds
if view:
    from matplotlib import pyplot as plt

device = 'cuda' # The device to run the reconstruction on

# Path of the .cxi file that is loaded as the ptychography dataset
data_filename = 'data/magnetic_ptycho_rcp.cxi'

dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(data_filename,
                                                    cut_zeros=True)

#
# This section runs the initial, full reconstruction. Because the probe
# reconstructions are marginal, even with the full dataset, we seed the
# 50%-data reconstructions for Fourier ring correlation with the probe
# recovered from the full data. This means we must do the full-data
# reconstruction first.
#


# The starting guess comes from the structural ptychography result. Its
# probe is propagated further below, because the magnetic sample sits in a
# different plane
initialization_file = 'results/single_shot_ptycho_full_rough.mat'

# Where the full-data reconstruction will be written
savefile = 'results/single_shot_magnetic_ptycho_rcp_full.mat'
print('To save in', savefile)

if view:
    dataset.inspect()

rough_results = loadmat(initialization_file)

# Options for the full-data ptychography model. Note that
# probe_fourier_crop is not among the options passed here.
model_options = dict(
    n_modes=3,
    # This is needed to set the probe_norm well
    propagation_distance=30e-6,
    translation_scale=10,
    # We need a rather tight support initially
    probe_support_radius=177,
    simulate_probe_translation=False,
    obj_view_crop=-200,
    units='um',
)
model = cdtools.models.FancyPtycho.from_dataset(dataset, **model_options)
    
propagators = cdtools.tools.propagators

# Start from the centered probe of the rough reconstruction
rough_probe = center_probe(t.as_tensor(rough_results['probe']))

# Propagate that probe guess to the correct plane, 51 um upstream. The grid
# step and wavelength are read from the rough results.
wavelength = rough_results['wavelength'][0,0]
grid_step = np.abs(rough_results['basis'][0,1])

angular_spectrum = propagators.generate_angular_spectrum_propagator(
    rough_probe.shape[-2:], [grid_step, grid_step], wavelength, -51e-6
)
shifted_probe = propagators.near_field(rough_probe, angular_spectrum)

# The magnetic data uses cropped diffraction patterns (lower achievable
# resolution), so the probe guess is cropped in Fourier space to match.
# Negative pad widths trim 100 pixels from each edge of the last two
# dimensions.
crop = (-100,) * 4
probe_spectrum = t.nn.functional.pad(propagators.far_field(shifted_probe),
                                     crop)
shifted_probe = propagators.inverse_far_field(probe_spectrum)

model.probe.data = shifted_probe / model.probe_norm

# Rough correction for the difference in total intensity between the
# magnetic and structural datasets
model.probe.data /= 4

# Apply the probe support to the initial guess
model.probe.data *= model.probe_support

# The background is also seeded from the rough reconstruction: negative
# values are clamped to zero and it is cropped like the probe.
# NOTE(review): two square roots are applied (i.e. the fourth root of the
# saved background) - confirm this matches how the model parameterizes its
# background.
saved_background = t.clamp(t.as_tensor(rough_results['background']), 0)
saved_background = t.nn.functional.pad(saved_background, crop)
model.background.data = saved_background.sqrt().sqrt()
    
# Move the model to the compute device, and have the dataset use it too
model.to(device=device)
dataset.get_as(device=device)

if view:
    model.inspect(dataset)


def _run_stage(model, dataset, n_epochs, *, lr, batch_size, schedule,
               inspect_every):
    """Run one Adam stage, printing a report after every epoch.

    When plotting is enabled (module-level ``view``), the model is also
    inspected each time ``model.epoch`` is a multiple of ``inspect_every``.
    """
    epochs = model.Adam_optimize(n_epochs, dataset, lr=lr,
                                 batch_size=batch_size, schedule=schedule)
    for _ in epochs:
        print(model.report())
        if model.epoch % inspect_every == 0 and view:
            model.inspect(dataset)


# Two aggressive warm-up rounds. Each one resets the object to 1 and then
# reconstructs again, with the goal of recovering good probe positions
# before the final reconstruction.
for _ in range(2):
    model.obj.data[:] = 1

    # Short stage with the probe and the weights held fixed
    model.probe.requires_grad = False
    model.weights.requires_grad = False
    _run_stage(model, dataset, 15, lr=0.03, batch_size=25, schedule=False,
               inspect_every=10)

    # Longer stage with the weights and the probe free again
    model.weights.requires_grad = True
    model.probe.requires_grad = True
    _run_stage(model, dataset, 50, lr=0.01, batch_size=25, schedule=False,
               inspect_every=10)

# Final, long round with learning rate scheduling. The probe support
# constraint is removed by setting the support to 1 everywhere.
model.probe.requires_grad = True
model.probe_support[:] = 1
model.weights.requires_grad = True
_run_stage(model, dataset, 500, lr=0.003, batch_size=50, schedule=True,
           inspect_every=25)
        
# Orthogonalize the recovered probes before saving
model.tidy_probes()

# Collect the results and write them to the full-data save file
full_results = model.save_results(dataset)
savemat(savefile, full_results)

# With plotting enabled, show the final state and the comparison with the
# data
if view:
    model.inspect(dataset)
    model.compare(dataset)
    

#
# Doing the 50/50 reconstructions for Fourier ring correlations
#


# Now we initialize with the full-data reconstructed results. This is the
# same file that the full reconstruction above was saved to
initialization_file = 'results/single_shot_magnetic_ptycho_rcp_full.mat'

full = loadmat(initialization_file)

# We then load the recovered translations, replacing the translations
# stored in the dataset before it is split
dataset.translations = t.as_tensor(full['translations'])

# We split the dataset into two disjoint halves to calculate an FRC
dataset_1, dataset_2 = split_dataset(dataset)


# Both halves are seeded from the full-data results already loaded into
# `full`, so the .mat file does not need to be re-read on every pass
for plan, dataset in [('half_1', dataset_1),
                      ('half_2', dataset_2)]:
    # NOTE: the loop variable rebinds the module-level name `dataset`, which
    # refers to the current half from here on

    savefile = f'results/single_shot_magnetic_ptycho_rcp_{plan}.mat'
    print('To save in', savefile)

    if view:
        dataset.inspect()

    # Next, we create a ptychography model from this half of the data, with
    # the same options as the full-data model except that no probe support
    # radius is given
    model = cdtools.models.FancyPtycho.from_dataset(
        dataset,
        n_modes=3,
        propagation_distance=30e-6,# This is needed to set the probe_norm well
        translation_scale=10,
        simulate_probe_translation=False,
        obj_view_crop=-200,
        units='um',
    )

    # We initialize with the probe from the full-data reconstruction. It is
    # used as saved, with no re-centering, propagation, or cropping
    probe = t.as_tensor(full['probe'])

    model.probe.data = probe / model.probe_norm

    # This accounts for a slow increase of probe intensity over the full
    # reconstruction (and corresponding decrease in object transmission).
    # It will also be normalized in the later analysis of the results, but
    # this keeps the learning rates more consistent to also roughly account
    # for it here
    model.probe.data /= 2

    # We also initialize with the background from the full reconstruction,
    # with negative values clamped to zero.
    # NOTE(review): two square roots are applied (i.e. the fourth root of
    # the saved background) - confirm this matches how the model
    # parameterizes its background.
    bkg = t.clamp(t.as_tensor(full['background']),0)

    model.background.data = t.sqrt(t.sqrt(bkg))

    # we move to the correct device
    model.to(device=device)
    dataset.get_as(device=device)

    if view:
        model.inspect(dataset)

    # First, a short round with the probe and the weights held fixed, so
    # that the freshly created model adapts to the seeded probe
    model.probe.requires_grad=False
    model.weights.requires_grad=False
    for _ in model.Adam_optimize(15, dataset, lr=0.03, batch_size=25, schedule=False):
        print(model.report())
        if model.epoch % 5 == 0 and view:
            model.inspect(dataset)

    # Then a longer round with the weights and the probe free again
    model.weights.requires_grad=True
    model.probe.requires_grad=True
    for _ in model.Adam_optimize(50, dataset, lr=0.01, batch_size=25, schedule=False):
        print(model.report())
        if model.epoch % 10 == 0 and view:
            model.inspect(dataset)

    # And we finish the reconstruction with a final, long round including
    # learning rate scheduling to really converge well. The probe and
    # weights are already trainable at this point; the flags are restated
    # so that this stage reads on its own
    model.probe.requires_grad=True
    model.weights.requires_grad=True
    for _ in model.Adam_optimize(500, dataset, lr=0.003, batch_size=50, schedule=True):
        print(model.report())
        if model.epoch % 25 == 0 and view:
            model.inspect(dataset)

    # We orthogonalize the resulting probes
    model.tidy_probes()

    # And save the results
    savemat(savefile, model.save_results(dataset))

    # Finally, we plot the results
    if view:
        model.inspect(dataset)
        model.compare(dataset)
        

# Show any figures produced above (plt is only imported when view is True)
if view:
    plt.show()
